{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Impact of 401(k) eligibility on net financial assets\n",
    "\n",
    "In this case study, we will use real-world data from 401(k) analysis to explain how Causality library can be used to estimate average treatment effect (ATE) and conditional ATE (CATE)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Background\n",
    "\n",
    "In the early 1980s, the United States government introduced several tax deferred savings options for employees in an effort to increase individual saving for retirement. One popular option is the 401(k) plan, which allows employees to contribute a portion of their wages to individual accounts. The goal here is to understand the effect of 401(k) eligibility on net financial assets (which is a sum of 401(k) balances and non-401(k) assets) considering heterogeneity due to individual's characteristics (income in particular).\n",
    "\n",
    "Since 401(k) plans are provided by employers, only employees of companies that offer those plans are eligible for participation. As such, we are dealing with a non-randomized study. Several factors (e.g. education, preference for saving) affect 401(k) eligibility as well as net financial assets."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "We consider a sample from the Survey of Income and Program Participation in 1991. The sample consists of households where the reference individual is 25-64 years old, and at least one individual is employed but no one is self-employed. There are records of 9915 households in the sample. For each household, 44 variables are recorded that include the eligibility of the household reference person for the 401(k) plan (the treatment), net financial assets (the outcome), and other covariates, such as age, income, family size, education, marital status, etc. We consider 16 covariates in particular.\n",
    "\n",
    "We summarise the variables used for this case study in the table below.\n",
    "\n",
    "| Variable Name | Type | Details |\n",
    "| --- | --- | --- |\n",
    "|e401|Treatment|eligibility for the 401(k) plan|\n",
    "|net_tfa|Outcome|net financial assets (in USD)|\n",
    "|age|Covariate|Age|\n",
    "|inc|Covariate|income (in USD)|\n",
    "|fsize|Covariate|family size|\n",
    "|educ|Covariate|education (in years)|\n",
    "|male|Covariate|is a male?|\n",
    "|db|Covariate|defined benefit pension|\n",
    "|marr|Covariate|married?|\n",
    "|twoearn|Covariate|two earners|\n",
    "|pira|Covariate|participation in IRA|\n",
    "|hown|Covariate|home owner?|\n",
    "|hval|Covariate|home value (in USD)|\n",
    "|hequity|Covariate|home equity (in USD)|\n",
    "|hmort|Covariate|home mortgage (in USD)|\n",
    "|nohs|Covariate|no high-school? (one-hot encoded)|\n",
    "|hs|Covariate|high-school? (one-hot encoded)|\n",
    "|smcol|Covariate|some-college? (one-hot encoded)|\n",
    "\n",
    "\n",
    "The dataset is publicly available online from the [`hdm`](https://rdrr.io/cran/hdm/man/pension.html) R package. For more details about the data set, we refer the interested reader to the following paper:\n",
    "\n",
    "V. Chernohukov, C. Hansen (2004). [The impact of 401(k) participation on the wealth distribution: An instrumental quantile regression analysis](http://www.mit.edu/~vchern/papers/ch_401k.pdf). The Review of Economic and Statistics 86 (3), 735–751. \n",
    "\n",
    "Let's load and analyse the data first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "df = pd.read_csv(\"pension.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Effect of 401(k) Eligibility on Net Financial Assets, Conditioned on Income\n",
    "\n",
    "First we construct a causal graph of 401(k) plan eligibility (the treatment $T$), net financial assets (the outcome $Y$), control variables $W$ we adjust for assuming that they block all back-door paths between $Y$ and $T$, and income $X$ (the covariate of interest based on which we want to study the heterogeneity of treatment effect)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "import dowhy.gcm as gcm\n",
    "\n",
    "treatment_var = \"e401\"\n",
    "outcome_var = \"net_tfa\"\n",
    "covariates = [\"age\",\"inc\",\"fsize\",\"educ\",\"male\",\"db\",\n",
    "              \"marr\",\"twoearn\",\"pira\",\"hown\",\"hval\",\n",
    "              \"hequity\",\"hmort\",\"nohs\",\"hs\",\"smcol\"]\n",
    "\n",
    "edges = [(treatment_var, outcome_var)]\n",
    "edges.extend([(covariate, treatment_var) for covariate in covariates])\n",
    "edges.extend([(covariate, outcome_var) for covariate in covariates])\n",
    "\n",
    "causal_graph = nx.DiGraph(edges)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gcm.util.plot(causal_graph, figure_size=[20, 20])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we created a simplified graph where there are no interactions between covariates (i.e. nodes in $X \\cup W$). Most likely, that is not the case in practice. However, as we take joint samples of the covariates—directly from the observed data—later to estimate CATEs, we can ignore their interactions. \n",
    " \n",
    "Before we assign causal models to variables, let's plot their histograms to get an idea on the distribution of variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "cols = [treatment_var, outcome_var]\n",
    "cols.extend(covariates)\n",
    "plt.figure(figsize=(10,5))\n",
    "for i, col in enumerate(cols):\n",
    "    plt.subplot(3,6,i+1)\n",
    "    plt.grid(False)\n",
    "    plt.hist(df[col])\n",
    "    plt.xlabel(col)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We observe that real-valued variables do not follow well-known parameteric distributions like Gaussian. Therefore, we fit empirical distributions whenever those variables do not have parents, which is also suitable for categorical variables. \n",
    "\n",
    "Let's assign causal models to variables. For the treatment variable, we assign a classifier functional causal model (FCM) with a random forest classifier. For the outcome variable, we assign an additive noise model with a random forest regression as a function and empirical distribution for the noise. We assign empirical distributions to other variables as they do not have parents in the causal graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "causal_model = gcm.StructuralCausalModel(causal_graph)\n",
    "causal_model.set_causal_mechanism(treatment_var, gcm.ClassifierFCM(gcm.ml.create_random_forest_classifier()))\n",
    "causal_model.set_causal_mechanism(outcome_var, gcm.AdditiveNoiseModel(gcm.ml.create_random_forest_regressor()))\n",
    "for covariate in covariates:\n",
    "    causal_model.set_causal_mechanism(covariate, gcm.EmpiricalDistribution())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To fit a classifier FCM, we cast the treatment column to string type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.astype({treatment_var: str})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-block alert-info\">\n",
    "Instead of assigning the models manually, we can also automate this **if** we don't have prior knowledge or are not familiar with the statistical implications:\n",
    "    \n",
    "> gcm.auto.assign_causal_mechanisms(causal_model, df)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With that, we can now fit the learn the causal models from data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gcm.fit(causal_model, df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before computing CATE, we first divide households into equi-width bins of income percentiles. This allows us to study the impact on various income groups."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "percentages = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]\n",
    "bin_edges = [0]\n",
    "bin_edges.extend(np.quantile(df.inc, percentages[1:]).tolist())\n",
    "bin_edges[-1] += 1 # adding 1 to the last edge as last edge is excluded by np.digitize\n",
    "\n",
    "groups = [f'{percentages[i]*100:.0f}%-{percentages[i+1]*100:.0f}%' for i in range(len(percentages)-1)]\n",
    "group_index_to_group_label = dict(zip(range(1, len(bin_edges)+1), groups))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can compute CATE. To this end, we perform a randomised intervention on the treatment variable in the fitted causal graph, draw samples from the interventional distribution, group observations by the income group, and then compute the treatment effect in each group. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(47)\n",
    "\n",
    "def estimate_cate():\n",
    "    samples = gcm.interventional_samples(causal_model, \n",
    "                                         {treatment_var: lambda x: np.random.choice(['0', '1'])},\n",
    "                                         observed_data=df)\n",
    "    eligible = samples[treatment_var] == '1'\n",
    "    ate = samples[eligible][outcome_var].mean() - samples[~eligible][outcome_var].mean()\n",
    "    result = dict(ate = ate)\n",
    "    \n",
    "    group_indices = np.digitize(samples['inc'], bin_edges)\n",
    "    samples['group_index'] = group_indices\n",
    "    \n",
    "    for group_index in group_index_to_group_label:\n",
    "        group_samples = samples[samples['group_index'] == group_index]\n",
    "        eligible_in_group = group_samples[treatment_var] == '1'\n",
    "        cate = group_samples[eligible_in_group][outcome_var].mean() - group_samples[~eligible_in_group][outcome_var].mean()\n",
    "        result[group_index_to_group_label[group_index]] = cate\n",
    "        \n",
    "    return result\n",
    "\n",
    "group_to_median, group_to_ci = gcm.confidence_intervals(estimate_cate, num_bootstrap_resamples=100)\n",
    "print(group_to_median)\n",
    "print(group_to_ci)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The average treatment effect of 401(k) eligibility on net financial assets is positive as indicated by the confidence interval $[4902.24, 8486.89]$. Now, let's plot CATEs of various income groups to get a clear picture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(8,4))\n",
    "for x, group in enumerate(groups):\n",
    "    ci = group_to_ci[group]\n",
    "    plt.plot((x, x), (ci[0], ci[1]), 'ro-', color='orange')\n",
    "ax = fig.axes[0]\n",
    "ax.spines['right'].set_visible(False)\n",
    "ax.spines['top'].set_visible(False)\n",
    "plt.xticks(range(len(groups)), groups)\n",
    "plt.xlabel('Income group')\n",
    "plt.ylabel('ATE of 401(k) eligibility on net financial assets')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The impact increases as one moves from lower to higher income groups. This result seems to be consistent with the resource constraints of the different income groups."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}